#' Show Group Enrichment Result
#'
#' See [group_enrichment] for examples.
#' NOTE: the tile fill and the tile text show different quantities
#' (see `add_text_annotation`).
#'
#' @inheritParams ggplot2::facet_wrap
#' @param df_enrich result `data.frame` from [group_enrichment].
#' @param return_list if `TRUE`, return a list of `ggplot` objects (one per
#' `grp_var`, named by it) so users can combine multiple plots with other R
#' packages like `patchwork`.
#' @param add_text_annotation if `TRUE`, add a text annotation in each tile.
#' When p values are shown as fill color, the text shows the relative change;
#' when the relative change is shown as fill color, the text shows the p value
#' (or FDR, see `use_fdr`).
#' @param fill_by_p_value if `TRUE`, show log10-based p values as fill color.
#' The sign (+/-) of the value indicates the change direction (positive when
#' the relative change is `>= 1`). If `FALSE`, the relative change is shown
#' as fill color instead.
#' @param use_fdr if `TRUE`, show FDR values instead of raw p values.
#' @param cut_p_value if `TRUE`, cut p values into 5 regions for better
#' visualization. Only works when `fill_by_p_value = TRUE`; ignored otherwise.
#' @param cut_breaks when `cut_p_value` is `TRUE`, the (signed, log10-based)
#' breaks passed to [cut()].
#' @param cut_labels when `cut_p_value` is `TRUE`, the labels of the regions
#' defined by `cut_breaks`.
#' @param fill_scale a `Scale` object generated by the `ggplot2` package to
#' set colors for continuous values. Not used when p values are cut.
#' @param cluster_row if `TRUE`, reorder rows (subgroups) by hierarchical
#' clustering ('complete' method) of the scaled relative changes.
#' @param ... other parameters passing to [ggplot2::facet_wrap], only used
#' when `return_list` is `FALSE`.
#'
#' @return a `ggplot` object, or a named list of `ggplot` objects when
#' `return_list` is `TRUE`.
#' @export
show_group_enrichment <- function(df_enrich,
                                  return_list = FALSE,
                                  scales = "free",
                                  add_text_annotation = TRUE,
                                  fill_by_p_value = TRUE,
                                  use_fdr = TRUE,
                                  cut_p_value = FALSE,
                                  cut_breaks = c(-Inf, -5, log10(0.05), -log10(0.05), 5, Inf),
                                  cut_labels = c("\u2193 1e-5", "\u2193 0.05", "non-significant", "\u2191 0.05", "\u2191 1e-5"),
                                  fill_scale = scale_fill_gradient2(
                                    low = "#08A76B", mid = "white", high = "red",
                                    midpoint = ifelse(fill_by_p_value, 0, 1)
                                  ),
                                  cluster_row = FALSE,
                                  ...) {
  # Binning is only defined for the signed log10 p values computed below;
  # honour the documented "ignored otherwise" instead of failing downstream
  # on a missing `p_value_up` column.
  cut_p_value <- cut_p_value && fill_by_p_value

  if (fill_by_p_value) {
    p_values <- if (use_fdr) df_enrich$fdr else df_enrich$p_value
    magnitude <- abs(log10(p_values))
    # Positive for enrichment (measure_observed >= 1), negative otherwise.
    df_enrich$p_value_up <- data.table::fifelse(
      df_enrich$measure_observed >= 1,
      magnitude,
      -magnitude
    )
  }

  # Single place for the plotting arguments shared by both output modes.
  draw <- function(data) {
    plot_enrichment_simple(data,
      x = "enrich_var", y = "grp1",
      fill_scale = fill_scale,
      fill_by_p_value = fill_by_p_value,
      cut_p_value = cut_p_value,
      cut_breaks = cut_breaks,
      cut_labels = cut_labels,
      add_text_annotation = add_text_annotation,
      use_fdr = use_fdr,
      cluster_row = cluster_row
    )
  }

  if (return_list) {
    # `keep = TRUE` leaves `grp_var` inside each nested table: row clustering
    # (`cluster_row = TRUE`) selects and splits by that column.
    nested <- dplyr::group_nest(df_enrich, .data$grp_var, keep = TRUE)
    p <- purrr::map(nested$data, draw)
    names(p) <- nested$grp_var
  } else {
    p <- draw(df_enrich) +
      facet_wrap(~grp_var, scales = scales, ...)
  }

  p
}

# Draw one enrichment heatmap (internal helper for show_group_enrichment()).
#
# Tiles are placed at (`x`, `y`). The fill is one of two columns:
# - the signed log10 p value column `p_value_up`, optionally binned with
#   cut();
# - `measure_observed`.
# When `add_text_annotation` is TRUE, the other quantity is printed in the tile.
#
# data: data.frame / data.table / tibble holding the columns named by `x` and
#   `y`, plus `measure_observed`, `fdr` or `p_value`, and `p_value_up` when
#   `fill_by_p_value` is TRUE. `grp_var`, if present, splits row clustering.
# fill_scale: ggplot2 scale used for continuous fills (unused when binning).
# cut_p_value: bin `p_value_up` with `cut_breaks` / `cut_labels`; only
#   honoured when `fill_by_p_value` is TRUE.
# cluster_row: reorder the `y` levels by hierarchical clustering (hclust()
#   default, 'complete' linkage) of scaled `measure_observed`, computed
#   separately within each `grp_var`.
#
# Returns a ggplot object.
plot_enrichment_simple <- function(data, x, y, fill_scale,
                                   fill_by_p_value = TRUE,
                                   cut_p_value = FALSE,
                                   cut_breaks = c(-Inf, -10, -1.3, 1.3, 10, Inf),
                                   cut_labels = c("< -10", "< -1.3", "nosig", "> 1.3", "> 10"),
                                   add_text_annotation = TRUE,
                                   use_fdr = TRUE,
                                   cluster_row = FALSE) {
  # `p_value_up` only exists when filling by p value; binning anything else
  # would call cut() on a NULL column.
  cut_p_value <- cut_p_value && fill_by_p_value
  p_col <- if (use_fdr) "fdr" else "p_value"

  # Round the quantity that is printed as text inside the tiles.
  if (fill_by_p_value) {
    data$measure_observed <- round(data$measure_observed, 2)
  } else {
    data[[p_col]] <- round(data[[p_col]], 3)
  }

  if (cut_p_value) {
    data$p_value_up <- cut(data$p_value_up,
      breaks = cut_breaks,
      labels = cut_labels
    )
  }

  # Optional row (subgroup) clustering.
  if (isTRUE(cluster_row)) {
    has_grp <- "grp_var" %in% names(data)
    cols <- c(x, y, if (has_grp) "grp_var", "measure_observed")
    # Plain data.frame so column selection works for data.table, tibble and
    # data.frame input alike.
    wide <- tidyr::pivot_wider(
      as.data.frame(data)[, cols, drop = FALSE],
      names_from = dplyr::all_of(x),
      values_from = "measure_observed"
    )
    blocks <- if (has_grp) {
      dplyr::group_split(wide, .data$grp_var, .keep = FALSE)
    } else {
      list(wide)
    }

    # Subgroup labels of one block, in dendrogram leaf order.
    order_block <- function(block) {
      block <- tibble::column_to_rownames(as.data.frame(block), var = y)
      # hclust() needs at least two observations; keep a lone row as-is.
      if (nrow(block) < 2L) {
        return(rownames(block))
      }
      tree <- block %>%
        scale() %>%
        stats::dist() %>%
        stats::hclust() %>%
        stats::as.dendrogram()
      rownames(block)[stats::order.dendrogram(tree)]
    }

    orders <- unique(unlist(lapply(blocks, order_block), use.names = FALSE))
    message("All subgroup orders: ", paste(orders, collapse = ", "))
    data[[y]] <- factor(data[[y]], levels = orders)
  }

  # `!!` injects the column names now, so a data column that happens to be
  # called `x`, `y`, ... cannot shadow them inside the aesthetic mapping.
  p <- ggplot(data, ggplot2::aes(x = .data[[!!x]], y = .data[[!!y]]))

  if (cut_p_value) {
    p <- p +
      geom_tile(mapping = ggplot2::aes(fill = .data[["p_value_up"]])) +
      scale_fill_manual(
        drop = FALSE,
        na.value = "grey",
        values = c("#08A76B", "#98FF97", "white", "orange", "red")
      )
  } else {
    fill_col <- if (fill_by_p_value) "p_value_up" else "measure_observed"
    p <- p +
      geom_tile(mapping = ggplot2::aes(fill = .data[[!!fill_col]])) +
      fill_scale
  }

  legend_label <- if (!fill_by_p_value) {
    "FC"
  } else if (use_fdr) {
    "FDR"
  } else {
    "P-value"
  }
  # Only the p values are on a (signed) log10 scale. `measure_observed` is
  # filled on its raw scale, and binned values are labelled by `cut_labels`.
  if (fill_by_p_value && !cut_p_value) {
    legend_label <- paste0("log10\n(", legend_label, ")")
  }
  p <- p +
    labs(
      x = "Variable",
      y = "Subgroup",
      fill = legend_label
    ) +
    scale_x_discrete(expand = expansion(mult = c(0, 0))) +
    scale_y_discrete(expand = expansion(mult = c(0, 0)))

  if (add_text_annotation) {
    label_col <- if (fill_by_p_value) "measure_observed" else p_col
    p <- p +
      geom_text(
        mapping = ggplot2::aes(label = .data[[!!label_col]]),
        size = 3
      )
  }

  p
}
